import torch
import warnings
import numpy as np
import sys
from PIL import Image
import torchvision
from ..base import BaseModel
from ...smp import get_rank_and_world_size
from huggingface_hub import snapshot_download


class PLLaVA(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = False
    VIDEO_LLM = True

    def __init__(self, model_path='ermu2001/pllava-13b', dir_root=None, **kwargs):
        if dir_root is not None:
            sys.path.append(dir_root)
        try:
            from tasks.eval.model_utils import load_pllava
        except ImportError:
            warnings.warn(
                'Please install the requirements and pass the PLLaVA repo root via `dir_root`. '
                'Follow the instructions at https://github.com/magic-research/PLLaVA.'
            )
            sys.exit(-1)

        rank, world_size = get_rank_and_world_size()
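        # inference defaults: 16 uniformly sampled frames, LoRA enabled, 672px inputs,
        # and visual tokens pooled to a (frames, height, width) grid of (16, 12, 12)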
        self.nframe = 16
        self.use_lora = True
        self.lora_alpha = 4
        self.pooling_shape = (16, 12, 12)
        self.RESOLUTION = 672
        self.model_path = model_path
        # Note: larger checkpoints (30B+) can exhaust GPU memory and may even bring nodes down.
        weight_dir = snapshot_download(model_path)
        self.model, self.processor = load_pllava(
            model_path, num_frames=self.nframe, use_lora=self.use_lora,
            weight_dir=weight_dir, lora_alpha=self.lora_alpha, pooling_shape=self.pooling_shape
        )

        # move the model to this rank's GPU and switch to eval mode
        self.model = self.model.to(torch.device(rank))
        self.model = self.model.eval()

    def load_video(self, video_path, num_segments=8, resolution=336):
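        """Decode the video with decord, uniformly sample `num_segments` frames, and resize each frame to `resolution`."""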
        from decord import VideoReader, cpu
        transforms = torchvision.transforms.Resize(size=resolution)
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        num_frames = len(vr)
        frame_indices = self.get_index(num_frames, num_segments)
        images_group = list()
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].asnumpy())
            images_group.append(transforms(img))
        return images_group

    def get_index(self, num_frames, num_segments):
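        """Return `num_segments` frame indices spread uniformly over the video, one from the center of each segment."""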
        seg_size = float(num_frames - 1) / num_segments
        start = int(seg_size / 2)
        offsets = np.array([
            start + int(np.round(seg_size * idx)) for idx in range(num_segments)
        ])
        return offsets

    def generate_inner(self, message, dataset=None):
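        """Sample frames from the video, build the dataset-specific conversation, and query PLLaVA for a response."""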
        from tasks.eval.model_utils import pllava_answer
        from tasks.eval.eval_utils import conv_templates

        question, video = self.message_to_promptvideo(message)

        img_list = self.load_video(video, num_segments=self.nframe, resolution=self.RESOLUTION)

        if self.model_path == 'ermu2001/pllava-34b':  # the 34B model uses a slightly different conversation template
            if dataset in ['Video-MME', 'MVBench', 'MVBench_MP4']:  # MCQ dataset
                conv_mode = 'eval_mvbench_llavanext'
            else:  # VQA dataset
                conv_mode = 'eval_videoqa_llavanext'
        else:
            if dataset in ['Video-MME', 'MVBench', 'MVBench_MP4']:  # MCQ dataset
                conv_mode = 'eval_mvbench'
            else:  # VQA dataset
                conv_mode = 'eval_videoqabench'

        conv = conv_templates[conv_mode].copy()
        if dataset in ['MVBench', 'MVBench_MP4']:
            conv.user_query(message[1]['value'], message[0]['value'], message[-2]['value'], is_mm=True)
            conv.assistant_response(message[-1]['value'])
        else:
            conv.user_query(question, is_mm=True)
        llm_response, conv = pllava_answer(
            conv=conv, model=self.model, processor=self.processor,
            do_sample=False, img_list=img_list, max_new_tokens=256, print_res=False
        )
        if dataset in ['MVBench', 'MVBench_MP4']:
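            # the response echoes the assistant prefix set above; keep only the text after it and restore the leading '('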
            llm_response = '(' + ''.join(llm_response.split(message[-1]['value'])[1:])
        return llm_response
